{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0f7c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Google Vertex Supervised Fine-Tuning\n",
    "\n",
    "This recipe allows TensorZero users to fine-tune Gemini models using their own data.\n",
    "Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL=\"http://chuser:chpassword@localhost:8123/tensorzero\"`\n",
    "- Set the `GCP_VERTEX_CREDENTIALS_PATH`, `GCP_PROJECT_ID`, `GCP_LOCATION`, and `GCP_BUCKET_NAME` environment variables.\n",
    "- Create local authentication credentials `gcloud auth application-default login`\n",
    "- You may need to [Create a Bucket](https://cloud.google.com/storage/docs/creating-buckets) on GCP, if you do not already have one.\n",
    "- Update the following parameters:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = \"../../../examples/data-extraction-ner/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "METRIC_NAME = \"jaccard_similarity\"\n",
    "\n",
    "# The name of the variant to use to grab the templates used for fine-tuning\n",
    "TEMPLATE_VARIANT_NAME = \"gpt_4o_mini\"\n",
    "\n",
    "# If the metric is a float metric, you can set the threshold to filter the data\n",
    "FLOAT_METRIC_THRESHOLD = 0.5\n",
    "\n",
    "# Fraction of the data to use for validation\n",
    "VAL_FRACTION = 0.2\n",
    "\n",
    "# Maximum number of samples to use for fine-tuning\n",
    "MAX_SAMPLES = 100_000\n",
    "\n",
    "# The name of the model to fine-tune (supported models: https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning)\n",
    "MODEL_NAME = \"gemini-2.0-flash-lite-001\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "tensorzero_path = os.path.abspath(os.path.join(os.getcwd(), \"../../../\"))\n",
    "if tensorzero_path not in sys.path:\n",
    "    sys.path.append(tensorzero_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import tempfile\n",
    "import time\n",
    "import warnings\n",
    "from typing import Any, Dict, List, Optional\n",
    "\n",
    "import toml\n",
    "import vertexai\n",
    "from google.cloud import storage\n",
    "from google.cloud.aiplatform_v1.types import JobState\n",
    "from IPython.display import clear_output\n",
    "from tensorzero import (\n",
    "    FloatMetricFilter,\n",
    "    RawText,\n",
    "    TensorZeroGateway,\n",
    "    Text,\n",
    "    Thought,\n",
    "    ToolCall,\n",
    "    ToolResult,\n",
    ")\n",
    "from tensorzero.util import uuid7\n",
    "from vertexai.tuning import sft\n",
    "\n",
    "from recipes.util import train_val_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize Vertex AI\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vertexai.init(project=os.environ[\"GCP_PROJECT_ID\"], location=os.environ[\"GCP_LOCATION\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the TensorZero client\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorzero_client = TensorZeroGateway.build_embedded(\n",
    "    config_file=CONFIG_PATH,\n",
    "    clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"],\n",
    "    timeout=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the metric filter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "comparison_operator = \">=\"\n",
    "metric_node = FloatMetricFilter(\n",
    "    metric_name=METRIC_NAME,\n",
    "    value=FLOAT_METRIC_THRESHOLD,\n",
    "    comparison_operator=comparison_operator,\n",
    ")\n",
    "# from tensorzero import BooleanMetricFilter\n",
    "# metric_node = BooleanMetricFilter(\n",
    "#     metric_name=METRIC_NAME,\n",
    "#     value=True  # or False\n",
    "# )\n",
    "\n",
    "metric_node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Query the inferences from ClickHouse.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stored_inferences = tensorzero_client.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME,\n",
    "    variant_name=None,\n",
    "    output_source=\"inference\",  # could also be \"demonstration\"\n",
    "    filters=metric_node,\n",
    "    limit=MAX_SAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Render the stored inferences\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rendered_samples = tensorzero_client.experimental_render_samples(\n",
    "    stored_inferences=stored_inferences,\n",
    "    variants={FUNCTION_NAME: TEMPLATE_VARIANT_NAME},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the data into training and validation sets for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_samples, val_samples = train_val_split(\n",
    "    rendered_samples,\n",
    "    val_size=VAL_FRACTION,\n",
    "    last_inference_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert inferences to vertex format\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "role_map = {\n",
    "    \"user\": \"user\",\n",
    "    \"assistant\": \"model\",\n",
    "    \"system\": \"system\",\n",
    "}\n",
    "\n",
    "\n",
    "def merge_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:\n",
    "    \"\"\"\n",
    "    Merge consecutive messages with the same role into a single message.\n",
    "    \"\"\"\n",
    "    merged: List[Dict[str, Any]] = []\n",
    "    for msg in messages:\n",
    "        role = msg[\"role\"]\n",
    "        parts = msg.get(\"parts\", [])\n",
    "        if merged and merged[-1][\"role\"] == role:\n",
    "            merged[-1][\"parts\"].extend(parts)\n",
    "        else:\n",
    "            merged.append({\"role\": role, \"parts\": list(parts)})\n",
    "    return merged\n",
    "\n",
    "\n",
    "def render_chat_message(\n",
    "    role: str,\n",
    "    content_blocks: List[Any],  # instances of Text, RawText, Thought, ToolCall, ToolResult\n",
    ") -> Optional[Dict[str, Any]]:\n",
    "    \"\"\"\n",
    "    Render a single chat message into Google “parts” format.\n",
    "    \"\"\"\n",
    "    parts: List[Dict[str, Any]] = []\n",
    "    for blk in content_blocks:\n",
    "        # plain text\n",
    "        if isinstance(blk, Text):\n",
    "            parts.append({\"text\": blk.text})\n",
    "        elif isinstance(blk, RawText):  # Verify if needed\n",
    "            parts.append({\"text\": blk.value})\n",
    "        # internal “thoughts”\n",
    "        elif isinstance(blk, Thought):\n",
    "            parts.append({\"text\": f\"<think>{blk.text}</think>\"})\n",
    "        # function call (assistant only)\n",
    "        elif isinstance(blk, ToolCall) and role == \"assistant\":\n",
    "            args = blk.raw_arguments\n",
    "            # raw_arguments might already be a dict or JSON string\n",
    "            if isinstance(args, str):\n",
    "                args = json.loads(args)\n",
    "            parts.append(\n",
    "                {\n",
    "                    \"functionCall\": {\n",
    "                        \"name\": blk.name,\n",
    "                        \"args\": args,\n",
    "                    }\n",
    "                }\n",
    "            )\n",
    "        # function result (user only)\n",
    "        elif isinstance(blk, ToolResult) and role == \"user\":\n",
    "            parts.append(\n",
    "                {\n",
    "                    \"functionResponse\": {\n",
    "                        \"name\": blk.name,\n",
    "                        \"response\": {\"result\": blk.result},\n",
    "                    }\n",
    "                }\n",
    "            )\n",
    "        else:\n",
    "            warnings.warn(\n",
    "                f\"Unsupported block type {type(blk)} in role={role}, skipping inference.\",\n",
    "                UserWarning,\n",
    "            )\n",
    "            return None\n",
    "    return {\"role\": role_map[role], \"parts\": parts}\n",
    "\n",
    "\n",
    "def inference_to_google(\n",
    "    inf,\n",
    ") -> Optional[Dict[str, Any]]:\n",
    "    \"\"\"\n",
    "    Convert a single rendered_inference into the Google Vertex format dict.\n",
    "    \"\"\"\n",
    "    model_input = inf.input\n",
    "    rendered_msgs: List[Dict[str, Any]] = []\n",
    "\n",
    "    # 1) systemInstruction\n",
    "    if model_input.system:\n",
    "        system_instruction = {\n",
    "            \"role\": role_map[\"system\"],\n",
    "            \"parts\": [{\"text\": model_input.system}],\n",
    "        }\n",
    "    else:\n",
    "        system_instruction = None\n",
    "\n",
    "    # 2) all user/assistant messages\n",
    "    for msg in model_input.messages:\n",
    "        rendered = render_chat_message(msg.role, msg.content)\n",
    "        if rendered is None:\n",
    "            return None\n",
    "        rendered_msgs.append(rendered)\n",
    "\n",
    "    # 3) the assistant’s output\n",
    "    #    (same logic as render_chat_message but without ToolResult)\n",
    "    out_parts: List[Dict[str, Any]] = []\n",
    "    for blk in inf.output:\n",
    "        if isinstance(blk, Text):\n",
    "            out_parts.append({\"text\": blk.text})\n",
    "        elif isinstance(blk, Thought):\n",
    "            out_parts.append({\"text\": f\"<think>{blk.text}</think>\"})\n",
    "        elif isinstance(blk, ToolCall):\n",
    "            args = blk.raw_arguments\n",
    "            if isinstance(args, str):\n",
    "                args = json.loads(args)\n",
    "            out_parts.append(\n",
    "                {\n",
    "                    \"functionCall\": {\n",
    "                        \"name\": blk.name,\n",
    "                        \"args\": args,\n",
    "                    }\n",
    "                }\n",
    "            )\n",
    "        else:\n",
    "            warnings.warn(\n",
    "                f\"Unsupported output block {type(blk)}, skipping inference.\",\n",
    "                UserWarning,\n",
    "            )\n",
    "            return None\n",
    "    rendered_msgs.append({\"role\": role_map[\"assistant\"], \"parts\": out_parts})\n",
    "\n",
    "    # 4) merge any consecutive roles and return\n",
    "    contents = merge_messages(rendered_msgs)\n",
    "    result = {\"contents\": contents}\n",
    "    if system_instruction:\n",
    "        result.update({\"systemInstruction\": system_instruction})\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = [inference_to_google(sample) for sample in train_samples]\n",
    "val_data = [inference_to_google(sample) for sample in val_samples]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Upload the training and validation datasets to GCP\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def upload_dataset_to_gcp(data: List[Dict[str, Any]], dataset_name: str, gcp_client: storage.Client) -> str:\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".jsonl\", delete=False) as f:\n",
    "        # Write the openai_messages to the temporary file\n",
    "        for item in data:\n",
    "            json.dump(item, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "\n",
    "        bucket = gcp_client.bucket(os.environ[\"GCP_BUCKET_NAME\"])\n",
    "        if not bucket.exists():\n",
    "            bucket.storage_class = \"STANDARD\"\n",
    "            bucket = gcp_client.create_bucket(bucket, location=\"us\")\n",
    "            print(\n",
    "                \"Created bucket {} in {} with storage class {}\".format(\n",
    "                    bucket.name, bucket.location, bucket.storage_class\n",
    "                )\n",
    "            )\n",
    "        blob = bucket.blob(dataset_name)\n",
    "\n",
    "        generation_match_precondition = 0\n",
    "        blob.upload_from_filename(f.name, if_generation_match=generation_match_precondition)\n",
    "\n",
    "\n",
    "gcp_client = storage.Client(project=os.environ[\"GCP_PROJECT_ID\"])\n",
    "\n",
    "train_file_name = f\"train_{uuid7()}.jsonl\"\n",
    "val_file_name = f\"val_{uuid7()}.jsonl\"\n",
    "\n",
    "\n",
    "upload_dataset_to_gcp(train_data, train_file_name, gcp_client)\n",
    "upload_dataset_to_gcp(val_data, val_file_name, gcp_client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Launch the fine-tuning job.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sft_tuning_job = sft.train(\n",
    "    source_model=MODEL_NAME,\n",
    "    train_dataset=f\"gs://{os.environ['GCP_BUCKET_NAME']}/{train_file_name}\",\n",
    "    validation_dataset=f\"gs://{os.environ['GCP_BUCKET_NAME']}/{val_file_name}\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wait for the fine-tuning job to complete.\n",
    "\n",
    "This cell will take a while to run.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = sft.SupervisedTuningJob(sft_tuning_job.resource_name)\n",
    "while True:\n",
    "    clear_output(wait=True)\n",
    "\n",
    "    try:\n",
    "        job_state = response.state\n",
    "        print(job_state)\n",
    "        if job_state in (\n",
    "            JobState.JOB_STATE_SUCCEEDED.value,\n",
    "            JobState.JOB_STATE_FAILED.value,\n",
    "            JobState.JOB_STATE_CANCELLED.value,\n",
    "        ):\n",
    "            break\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\")\n",
    "    response.refresh()\n",
    "    time.sleep(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the fine-tuning job is complete, you can add the fine-tuned model to your config file.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fine_tuned_model = response.tuned_model_endpoint_name.split(\"/\")[-1]\n",
    "model_config = {\n",
    "    \"models\": {\n",
    "        fine_tuned_model: {\n",
    "            \"routing\": [\"gcp_vertex_gemini\"],\n",
    "            \"providers\": {\n",
    "                \"gcp_vertex_gemini\": {\n",
    "                    \"type\": \"gcp_vertex_gemini\",\n",
    "                    \"endpoint_id\": fine_tuned_model,\n",
    "                    \"location\": os.environ[\"GCP_LOCATION\"],\n",
    "                    \"project_id\": os.environ[\"GCP_PROJECT_ID\"],\n",
    "                }\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(toml.dumps(model_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, add a new variant to your function to use the fine-tuned model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You're all set!\n",
    "\n",
    "You can change the weight to enable a gradual rollout of the new model.\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
